import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class Leetcode203 {

    /**
     * Removes every node whose {@code val} field equals {@code val} from the
     * singly linked list starting at {@code head}.
     *
     * <p>The list is modified in place: surviving nodes are relinked through
     * their {@code next} fields and no new nodes are created.
     *
     * @param head first node of the list; may be {@code null} for an empty list
     * @param val  value to remove
     * @return the first surviving node, or {@code null} if the list was empty
     *         or every node matched {@code val}
     */
    public ListNode removeElements(ListNode head, int val) {
        // Drop leading matches so the returned head is guaranteed to be kept.
        ListNode res = head;
        while (res != null && res.val == val) {
            res = res.next;
        }

        if (res == null) {
            return null;
        }

        // Invariant: 'parent' is always a node that stays in the list.
        ListNode parent = res;
        while (parent.next != null) {
            if (parent.next.val == val) {
                // Unlink the match but do NOT advance: the node that just
                // became parent.next may itself need to be removed.
                parent.next = parent.next.next;
            } else {
                parent = parent.next;
            }
        }

        return res;
    }
}
